#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This file defines the `TrtexecRunner` runner, which takes an ONNX model and
runs inference on the trtexec backend.

The runner implements the standard `BaseRunner` interface.
"""

import subprocess
import json
import tempfile
from collections import OrderedDict
import shutil

from polygraphy import mod, util
from polygraphy.backend.base import BaseRunner
from polygraphy.common import TensorMetadata
from polygraphy.datatype import DataType
from polygraphy.logger import G_LOGGER
from polygraphy.backend.onnx import onnx_from_path
from polygraphy.backend.onnx.util import get_input_metadata
from polygraphy.backend.trt import engine_from_bytes
from polygraphy.backend.common import bytes_from_path

trt = mod.lazy_import("tensorrt>=8.5")
np = mod.lazy_import("numpy")

# Default trtexec binary name; resolved through the system PATH when no
# explicit path is provided to the runner.
TRTEXEC_DEFAULT_PATH = "trtexec"
# Number of bytes in one mebibyte; used to convert memory pool sizes from
# bytes (Polygraphy's unit) to MiB (trtexec's unit).
MiB = 1024 ** 2

def which(path):
    """
    Report whether `path` names an executable that can be found on the
    system PATH (or is itself a path to an executable file).
    """
    resolved = shutil.which(path)
    return resolved is not None

def convert_shape(input_name, trt_spec):
    """
    Render a TensorRT shape sequence as a trtexec shape spec string,
    e.g. ("x", (1, 3, 224)) -> "x:1x3x224".
    """
    dims = 'x'.join(str(dim) for dim in trt_spec)
    return f"{input_name}:{dims}"

def parse_input_shapes(input_shapes):
    """
    Build the comma-separated spec passed to trtexec's --shapes option from
    a mapping of input name to a shape specification object (anything with
    a `.shape` attribute).

    Returns None when no input shapes were provided.
    """
    if not input_shapes:
        return None

    specs = [convert_shape(name, spec.shape) for name, spec in input_shapes.items()]
    return ','.join(specs)

def parse_profile_dicts(profile_dicts):
    """
    Build the comma-separated specs passed to trtexec's --minShapes,
    --optShapes, and --maxShapes options.

    Only the first optimization profile (`profile_dicts[0]`) is consumed;
    each entry maps an input name to a (min, opt, max) shape triple.

    Returns a (min, opt, max) tuple of spec strings, or (None, None, None)
    when no profiles were provided.
    """
    if not profile_dicts:
        return None, None, None

    mins, opts, maxs = [], [], []
    for name, (min_shape, opt_shape, max_shape) in profile_dicts[0].items():
        mins.append(convert_shape(name, min_shape))
        opts.append(convert_shape(name, opt_shape))
        maxs.append(convert_shape(name, max_shape))

    return ','.join(mins), ','.join(opts), ','.join(maxs)

def parse_layer_precisions(layer_precisions):
    """
    Build the comma-separated spec passed to trtexec's --layerPrecisions
    option from a mapping of layer name to a TensorRT precision string
    (e.g. "trt.float16").

    Returns None when no layer precisions were provided; unsupported
    precision strings are reported via G_LOGGER.critical.
    """
    if not layer_precisions:
        return None

    # Translation table from TensorRT precision names to trtexec's spelling.
    precision_names = {
        "trt.float32": "fp32",
        "trt.float16": "fp16",
        "trt.int32": "int32",
        "trt.int8": "int8",
    }

    specs = []
    for layer, precision in layer_precisions.items():
        if precision not in precision_names:
            G_LOGGER.critical(f"Unsupported precision type: {precision}")
        specs.append(f"{layer}:{precision_names.get(precision, '')}")

    return ','.join(specs)

def get_inference_time(perf_output):
    """
    Parse trtexec's textual performance summary and extract the latency
    statistics.

    Scans `perf_output` for the first line containing "Latency:" and parses
    the trailing "metric = value <unit>" pairs into a dict mapping metric
    name (e.g. "min", "mean", "median") to its value in milliseconds.

    Logs a critical error if no latency summary line is found.
    """
    marker = 'Latency:'
    for line in perf_output.split('\n'):
        pos = line.find(marker)
        if pos < 0:
            continue
        stats = {}
        for entry in line[pos + len(marker):].split(','):
            metric, _, value = entry.partition('=')
            # Value looks like " 1.5 ms"; keep only the number.
            stats[metric.strip()] = float(value.strip().split(' ')[0])
        return stats
    G_LOGGER.critical(f"Could not read inference time for trtexec backend. This "
    "might cause polygraphy to misbehave")

@mod.export()
class TrtexecRunner(BaseRunner):
    """
    Runs inference using custom trtexec. It accepts all ONNX models.

    The runner shells out to the `trtexec` binary: feed-dict inputs are
    serialized to temporary files and passed via `--loadInputs`, and outputs
    are read back from the JSON file produced by `--exportOutput`.
    """

    def __init__(self, model_path, model_type=None,
    trtexec_path=None, use_cuda_graph=None, avg_runs=None, best=None, duration=None, device=None, streams=None, min_timing=None, avg_timing=None, expose_dma=None, no_data_transfers=None, trtexec_warmup=None, trtexec_iterations=None, trtexec_export_times=None, trtexec_export_output=None, trtexec_export_profile=None, trtexec_export_layer_info=None,
    use_spin_wait=None, threads=None, use_managed_memory=None, dump_refit=None, dump_output=None, dump_profile=None, dump_layer_info=None, refit=None, separate_profile_run=None, trtexec_no_builder_cache=None, trtexec_profiling_verbosity=None, layer_output_types=None, use_dla_core=None,
    input_shapes=None, profile_dicts=None, tf32=None, fp16=None, int8=None, allow_gpu_fallback=None, precision_constraints=None, mem_pool_size=None, use_dla=None, layer_precisions=None, plugins=None, save_engine=None):
        """
        Args:
            model_path (str):
                    Path to the model to run: an ONNX model or a serialized
                    TensorRT engine, depending on `model_type`.
            model_type (str):
                    Either "onnx" or "engine"; selects whether `model_path`
                    is passed to trtexec via `--onnx` or `--loadEngine`.
            trtexec_path (str):
                    Path to the trtexec binary; defaults to "trtexec",
                    resolved through the system PATH.

        Most of the remaining arguments map one-to-one onto trtexec
        command-line flags (see the mapping built in `activate_impl`); any
        argument left as None simply omits the corresponding flag so that
        trtexec falls back to its own default.
        """
        super().__init__(prefix="trtexec-runner")
        self.model_path = model_path
        self.model_type = model_type
        # Fail fast if the trtexec binary cannot be located.
        self.trtexec_path = util.default(trtexec_path, TRTEXEC_DEFAULT_PATH)
        if not which(self.trtexec_path):
            G_LOGGER.critical(f"trtexec not found in given path: {self.trtexec_path}")

        # --- Inference/measurement options ---
        self.use_cuda_graph = use_cuda_graph
        self.avg_runs = avg_runs
        self.best = best
        self.duration = duration
        self.device = device
        self.streams = streams
        self.min_timing = min_timing
        self.avg_timing = avg_timing
        self.expose_dma = expose_dma
        self.no_data_transfers = no_data_transfers
        self.trtexec_warmup = trtexec_warmup
        self.trtexec_iterations = trtexec_iterations
        self.trtexec_export_times = trtexec_export_times
        self.trtexec_export_output = trtexec_export_output
        self.trtexec_export_profile = trtexec_export_profile
        self.trtexec_export_layer_info = trtexec_export_layer_info

        # --- Reporting and builder options ---
        self.use_spin_wait = use_spin_wait
        self.threads = threads
        self.use_managed_memory = use_managed_memory
        self.dump_refit = dump_refit
        self.dump_output = dump_output
        self.dump_profile = dump_profile
        self.dump_layer_info = dump_layer_info
        self.refit = refit
        self.separate_profile_run = separate_profile_run
        self.trtexec_no_builder_cache = trtexec_no_builder_cache
        self.trtexec_profiling_verbosity = trtexec_profiling_verbosity
        self.layer_output_types = layer_output_types
        # NOTE(review): `use_dla_core` is stored here but never added to the
        # trtexec command; `activate_impl` maps `--useDLACore` to
        # `self.use_dla` instead. Confirm whether this argument is
        # intentionally unused.
        self.use_dla_core = use_dla_core
        # Shape-related options are pre-rendered into trtexec's
        # "name:AxBxC,..." spec format.
        self.input_shapes = parse_input_shapes(input_shapes)
        self.min_shapes, self.opt_shapes, self.max_shapes = parse_profile_dicts(profile_dicts)
        # trtexec exposes TF32 as an opt-out flag (--noTF32), hence the inversion.
        self.no_tf32 = not tf32
        self.fp16 = fp16
        self.int8 = int8
        self.allow_gpu_fallback = allow_gpu_fallback
        self.precision_constraints = precision_constraints
        # When DLA is requested, DLA core 0 is always targeted.
        self.use_dla = 0 if use_dla else None
        self.plugins = plugins
        self.layer_precisions = parse_layer_precisions(layer_precisions)
        self.save_engine = save_engine
        if mem_pool_size is None:
            self.mem_pool_size = None
        else:
            # Render --memPoolSize as e.g. "workspace:16.0,dlaSRAM:1.0".
            # Keys appear to be TensorRT MemoryPoolType enum values
            # (0=workspace, 1=DLA SRAM, 2=DLA local DRAM, 3=DLA global DRAM)
            # -- TODO confirm against the caller. Values arrive in bytes.
            self.mem_pool_size = ""
            for k, v in mem_pool_size.items():
                v = v / MiB # Convert bytes into MiB
                if int(k) == 0:
                    self.mem_pool_size += f"workspace:{v},"
                elif int(k) == 1:
                    self.mem_pool_size += f"dlaSRAM:{v},"
                elif int(k) == 2:
                    self.mem_pool_size += f"dlaLocalDRAM:{v},"
                elif int(k) == 3:
                    self.mem_pool_size += f"dlaGlobalDRAM:{v},"
                else:
                    # Unrecognized pool types are silently skipped.
                    pass
            # Drop the trailing comma left by the loop above.
            self.mem_pool_size = self.mem_pool_size.rstrip(',')


    def activate_impl(self):
        """
        Initializes the construction of the command that needs to be run using
        the `trtexec` backend.

        Builds `self.cmd_args` from all constructor options; the per-inference
        `--loadInputs` flag is appended later by `construct_final_cmd`.
        """

        self.cmd_args = [self.trtexec_path]
        self.input_files = []

        # trtexec writes its outputs as JSON to this file; the handle is kept
        # open so `read_output_file` can parse it after inference.
        if self.trtexec_export_output:
            self.export_output_file_handle = open(self.trtexec_export_output, 'w+')
        else:
            # NOTE(review): delete=False means this temporary file is NOT
            # removed when the handle is closed in `deactivate_impl`; it
            # remains on disk.
            self.export_output_file_handle = tempfile.NamedTemporaryFile(delete=False)
        self.export_output_file_name = self.trtexec_export_output or self.export_output_file_handle.name

        model_type_mapping = self.get_model_type_mapping()

        # Mapping the args of polygraphy run to that of trtexec
        init_args_mapping = {
            **model_type_mapping,
            'exportOutput': self.export_output_file_name,

            'useCudaGraph': self.use_cuda_graph,
            'avgRuns': self.avg_runs,
            'best': self.best,
            'duration': self.duration,
            'device': self.device,
            'streams': self.streams,
            'minTiming': self.min_timing,
            'avgTiming': self.avg_timing,
            'exposeDMA': self.expose_dma,
            'noDataTransfers': self.no_data_transfers,
            'warmUp': self.trtexec_warmup,
            'iterations': self.trtexec_iterations,
            'exportTimes': self.trtexec_export_times,
            'exportProfile': self.trtexec_export_profile,
            'exportLayerInfo': self.trtexec_export_layer_info,

            'useSpinWait': self.use_spin_wait,
            'threads': self.threads,
            'useManagedMemory': self.use_managed_memory,
            'dumpRefit': self.dump_refit,
            'dumpOutput': self.dump_output,
            'dumpProfile': self.dump_profile,
            'dumpLayerInfo': self.dump_layer_info,
            'refit': self.refit,
            'separateProfileRun': self.separate_profile_run,
            'noBuilderCache': self.trtexec_no_builder_cache,
            'profilingVerbosity': self.trtexec_profiling_verbosity,
            'layerPrecisions': self.layer_precisions,
            'layerOutputTypes': self.layer_output_types,
            'useDLACore': self.use_dla,

            'shapes': self.input_shapes,
            'minShapes': self.min_shapes,
            'optShapes': self.opt_shapes,
            'maxShapes': self.max_shapes,
            'noTF32': self.no_tf32,
            'fp16': self.fp16,
            'int8': self.int8,
            'allowGPUFallback': self.allow_gpu_fallback,
            'precisionConstraints': self.precision_constraints,
            'memPoolSize': self.mem_pool_size,
            'plugins': self.plugins,
            'saveEngine': self.save_engine,

            # Mirror Polygraphy's logger verbosity onto trtexec's --verbose.
            'verbose': G_LOGGER.severity <= G_LOGGER.EXTRA_VERBOSE,
        }

        # `add_cmd_args` skips None values, so unset options add no flag.
        for arg, value in init_args_mapping.items():
            self.add_cmd_args(arg, value)

    def add_cmd_args(self, name, value=None):
        """
        Add the args to `self.cmd_args`. The function handles both
        args and kwargs.

        A None value is ignored; a boolean True adds a bare `--name` flag
        (False adds nothing); any other value adds `--name=value`.
        """
        if value is None:
            return

        if isinstance(value, bool):
            # For a bool, add the arg only if the corresponding value is `True`
            if value:
                self.cmd_args.append('--{}'.format(name))
        else:
            self.cmd_args.append('--{}={}'.format(name, value))

    def get_model_type_mapping(self):
        """
        Add the required args based on the model type.

        Returns:
            dict: {'onnx': model_path} for ONNX models or
            {'loadEngine': model_path} for serialized engines.

        Logs a critical error for any other model type.
        """
        if self.model_type == 'onnx':
            return {
                'onnx': self.model_path
            }

        if self.model_type == 'engine':
            return {
                'loadEngine':self.model_path
            }

        G_LOGGER.critical(f"Unsupported model type: {self.model_type}. `trtexec` only supports TensorRT engines and ONNX models")

    def generate_load_inputs_spec(self, feed_dict):
        """
        Reads the feed_dict metadata input dictionary and generates files to
        pass as command line input to trtexec binary.

        Each input array is written in raw binary form (`tofile`) to its own
        temporary file; the resulting "name:path,..." spec is stored on
        `self.load_inputs_spec` for use with trtexec's --loadInputs flag.
        """
        load_inputs_spec = []
        for input, values in feed_dict.items():
            # NOTE(review): delete=False -- these files stay on disk after
            # being closed in `deactivate_impl`.
            input_file = tempfile.NamedTemporaryFile(delete=False)
            values.tofile(input_file.name)
            load_inputs_spec.append('{}:{}'.format(input, input_file.name))
            # Track the handle so it can be closed on deactivation.
            self.input_files.append(input_file)
        self.load_inputs_spec = ','.join(load_inputs_spec)

    def read_output_file(self):
        """
        Reads the output from the output file generated by the trtexec binary.

        The file is a JSON list of entries with 'name', 'dimensions' (an
        "AxBxC" string), and 'values'; each entry becomes a reshaped numpy
        array keyed by tensor name.
        """
        outputs = OrderedDict()
        # NOTE(review): this reads from the handle opened in `activate_impl`
        # without seeking back to the start, so a second inference call would
        # read from the previous end-of-file position. Confirm whether
        # repeated `infer` calls are expected to work.
        content = json.load(self.export_output_file_handle)
        for entry in content:
            name, dimensions, values = entry['name'], entry['dimensions'], entry['values']
            dimensions = [int(d) for d in dimensions.split('x')]
            outputs[name] = np.array(values).reshape(*dimensions)
        return outputs

    def get_input_metadata_impl(self):
        # Input metadata is used by Polygraphy's default data loader to
        # determine the required shapes and datatypes of the input buffers.
        if self.model_type == 'onnx':
            model = onnx_from_path(self.model_path)
            return get_input_metadata(model.graph)

        if self.model_type =='engine':
            # Deserialize the engine and collect metadata for input tensors only.
            engine = engine_from_bytes(bytes_from_path(self.model_path))
            meta = TensorMetadata()
            for idx in range(engine.num_io_tensors):
                name = engine.get_tensor_name(idx)
                if engine.get_tensor_mode(name) != trt.TensorIOMode.INPUT:
                    continue
                meta.add(name=name, dtype=DataType.from_dtype(engine.get_tensor_dtype(name), "tensorrt"), shape=engine.get_tensor_shape(name))
            return meta
        # Any other model type falls through and implicitly returns None; in
        # practice `get_model_type_mapping` already aborts during activation
        # for unsupported types.


    def infer_impl(self, feed_dict):
        """
        Runs the trtexec binary on the given feed_dict and returns the
        outputs parsed from the exported JSON output file.
        """
        outputs = OrderedDict()

        # Adds other args that need to generated during inference. For example,
        # `feed_dict` is used to generate the args for `loadInputs`
        # NOTE(review): each call appends another `--loadInputs` flag to
        # `self.cmd_args`, so repeated inferences accumulate duplicate flags.
        self.construct_final_cmd(feed_dict)
        G_LOGGER.info(f"The trtexec command being run: {self.cmd_args}")
        # NOTE(review): the subprocess exit status is not checked; a failed
        # trtexec run is only detected indirectly when the latency summary
        # cannot be parsed from its stdout.
        perf_output = subprocess.run(self.cmd_args, stdout=subprocess.PIPE, text=True).\
                            stdout
        self.inference_time_stats = get_inference_time(perf_output)
        G_LOGGER.verbose(f"Inference time statistics: {self.inference_time_stats}")

        outputs = self.read_output_file()
        # inference_time_stats records time in 'ms'. However, polygraphy
        # expects time in seconds.
        self.inference_time = self.inference_time_stats['median'] / 1000
        return outputs

    def last_inference_time_stats(self):
        """
        Provides the inference time statistics.

        Returns the dict of latency metrics (in milliseconds) parsed from
        the most recent trtexec run.
        """
        return self.inference_time_stats

    def construct_final_cmd(self, feed_dict):
        """
        Constructs the complete command to run inference on trtexec backend.
        Adds any other args that need to generated during inference.
        """

        self.generate_load_inputs_spec(feed_dict)
        self.add_cmd_args('loadInputs', self.load_inputs_spec)

    def deactivate_impl(self):
        # Close the file handles created during activation/inference.
        # NOTE(review): these files were created with delete=False, so
        # closing them does NOT remove them from disk -- confirm whether
        # explicit cleanup (e.g. os.remove) is intended.
        self.export_output_file_handle.close()
        for input_file in self.input_files:
            input_file.close()
